{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d53b656e-aa3b-433d-8131-02083a985090",
   "metadata": {},
   "source": [
    "# AutoMM Problem Types And Metrics\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/problem_types_and_metrics.ipynb)\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/advanced_topics/problem_types_and_metrics.ipynb)\n",
    "\n",
    "\n",
    "AutoGluon Multimodal (AutoMM) supports a range of problem types covering different machine learning tasks. This tutorial introduces each problem type, the input modalities it supports, and its evaluation metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aa00faab-252f-44c9-b8f7-57131aa8251c",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "!pip install autogluon.multimodal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b02ccb8d-7f3f-4152-a526-ddcbf90e1a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf91ef1",
   "metadata": {},
   "source": [
    "Let's first write a helper function to print problem type information in a formatted way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c7a2c90e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal.constants import (\n",
    "    BINARY,\n",
    "    FEATURE_EXTRACTION,\n",
    "    FEW_SHOT_CLASSIFICATION,\n",
    "    IMAGE_SIMILARITY,\n",
    "    IMAGE_TEXT_SIMILARITY,\n",
    "    MULTICLASS,\n",
    "    NER,\n",
    "    OBJECT_DETECTION,\n",
    "    REGRESSION,\n",
    "    SEMANTIC_SEGMENTATION,\n",
    "    TEXT_SIMILARITY,\n",
    ")\n",
    "from autogluon.multimodal.problem_types import PROBLEM_TYPES_REG\n",
    "\n",
    "def print_problem_type_info(name: str, props):\n",
    "    \"\"\"Print the modalities, metrics, and capabilities of a problem type.\"\"\"\n",
    "    print(f\"\\n=== {name} ===\")\n",
    "    \n",
    "    print(\"\\nSupported Input Modalities:\")\n",
    "    # Sort the modalities so the output order is deterministic\n",
    "    for modality in sorted(props.supported_modality_type):\n",
    "        print(f\"- {modality}\")\n",
    "        \n",
    "    if hasattr(props, 'supported_evaluation_metrics') and props.supported_evaluation_metrics:\n",
    "        print(\"\\nEvaluation Metrics:\")\n",
    "        # Sort the metrics so the output order is deterministic\n",
    "        for metric in sorted(props.supported_evaluation_metrics):\n",
    "            if metric == props.fallback_evaluation_metric:\n",
    "                print(f\"- {metric} (default)\")\n",
    "            else:\n",
    "                print(f\"- {metric}\")\n",
    "                \n",
    "    if hasattr(props, 'support_zero_shot'):\n",
    "        print(\"\\nSpecial Capabilities:\")\n",
    "        print(f\"- Zero-shot prediction: {'Supported' if props.support_zero_shot else 'Not supported'}\")\n",
    "        print(f\"- Training support: {'Supported' if props.support_fit else 'Not supported'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55496174-9331-484c-aafe-ce0659123def",
   "metadata": {},
   "source": [
    "## Classification\n",
    "\n",
    "AutoGluon supports two types of classification:\n",
    "\n",
    "- Binary Classification (2 classes)\n",
    "- Multiclass Classification (3+ classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c4d9f14f-7948-4679-ad4c-3413d8f18e4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Binary Classification ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- categorical\n",
      "- image\n",
      "- image_base64_str\n",
      "- image_bytearray\n",
      "- numerical\n",
      "- text\n",
      "\n",
      "Evaluation Metrics:\n",
      "- acc\n",
      "- accuracy\n",
      "- average_precision\n",
      "- balanced_accuracy\n",
      "- f1\n",
      "- f1_macro\n",
      "- f1_micro\n",
      "- f1_weighted\n",
      "- log_loss\n",
      "- mcc\n",
      "- nll\n",
      "- pac\n",
      "- pac_score\n",
      "- precision\n",
      "- precision_macro\n",
      "- precision_micro\n",
      "- precision_weighted\n",
      "- quadratic_kappa\n",
      "- recall\n",
      "- recall_macro\n",
      "- recall_micro\n",
      "- recall_weighted\n",
      "- roc_auc (default)\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Not supported\n",
      "- Training support: Supported\n",
      "\n",
      "=== Multiclass Classification ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- categorical\n",
      "- image\n",
      "- image_base64_str\n",
      "- image_bytearray\n",
      "- numerical\n",
      "- text\n",
      "\n",
      "Evaluation Metrics:\n",
      "- acc\n",
      "- accuracy (default)\n",
      "- balanced_accuracy\n",
      "- f1_macro\n",
      "- f1_micro\n",
      "- f1_weighted\n",
      "- log_loss\n",
      "- mcc\n",
      "- nll\n",
      "- pac\n",
      "- pac_score\n",
      "- precision_macro\n",
      "- precision_micro\n",
      "- precision_weighted\n",
      "- quadratic_kappa\n",
      "- recall_macro\n",
      "- recall_micro\n",
      "- recall_weighted\n",
      "- roc_auc_ovo\n",
      "- roc_auc_ovo_macro\n",
      "- roc_auc_ovo_weighted\n",
      "- roc_auc_ovr\n",
      "- roc_auc_ovr_macro\n",
      "- roc_auc_ovr_micro\n",
      "- roc_auc_ovr_weighted\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Not supported\n",
      "- Training support: Supported\n"
     ]
    }
   ],
   "source": [
    "# Classification\n",
    "binary_props = PROBLEM_TYPES_REG.get(BINARY)\n",
    "multiclass_props = PROBLEM_TYPES_REG.get(MULTICLASS)\n",
    "print_problem_type_info(\"Binary Classification\", binary_props)\n",
    "print_problem_type_info(\"Multiclass Classification\", multiclass_props)"
   ]
  },
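  {
   "cell_type": "markdown",
   "id": "a7c1e3f0",
   "metadata": {},
   "source": [
    "Binary classification defaults to `roc_auc`, while multiclass classification defaults to `accuracy`. One reason to look beyond plain accuracy is class imbalance. The following is a minimal sketch in plain Python (not AutoMM's internal implementation; the labels are hypothetical) that compares `accuracy` with `balanced_accuracy`, computed here as the mean of per-class recall:\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "def accuracy(y_true, y_pred):\n",
    "    # Fraction of samples whose prediction matches the label\n",
    "    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)\n",
    "\n",
    "def balanced_accuracy(y_true, y_pred):\n",
    "    # Mean of per-class recall, so every class counts equally\n",
    "    correct = defaultdict(int)\n",
    "    total = defaultdict(int)\n",
    "    for t, p in zip(y_true, y_pred):\n",
    "        total[t] += 1\n",
    "        correct[t] += int(t == p)\n",
    "    return sum(correct[c] / total[c] for c in total) / len(total)\n",
    "\n",
    "# Hypothetical imbalanced labels: 8 negatives, 2 positives\n",
    "y_true = [0] * 8 + [1] * 2\n",
    "y_pred = [0] * 10  # a model that always predicts the majority class\n",
    "acc = accuracy(y_true, y_pred)\n",
    "bal_acc = balanced_accuracy(y_true, y_pred)\n",
    "print(acc, bal_acc)\n",
    "```\n",
    "\n",
    "On this toy data the majority-class predictor looks strong on accuracy but does no better than chance on balanced accuracy."
   ]
  },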
  {
   "cell_type": "markdown",
   "id": "c85b8b89",
   "metadata": {},
   "source": [
    "## Regression\n",
    "\n",
    "Regression problems support predicting numerical values from various input modalities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "adf13536",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Regression ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- categorical\n",
      "- image\n",
      "- image_base64_str\n",
      "- image_bytearray\n",
      "- numerical\n",
      "- text\n",
      "\n",
      "Evaluation Metrics:\n",
      "- mae\n",
      "- mape\n",
      "- mean_absolute_error\n",
      "- mean_absolute_percentage_error\n",
      "- mean_squared_error\n",
      "- median_absolute_error\n",
      "- mse\n",
      "- pearsonr\n",
      "- r2\n",
      "- rmse (default)\n",
      "- root_mean_squared_error\n",
      "- smape\n",
      "- spearmanr\n",
      "- symmetric_mean_absolute_percentage_error\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Not supported\n",
      "- Training support: Supported\n"
     ]
    }
   ],
   "source": [
    "# Regression\n",
    "regression_props = PROBLEM_TYPES_REG.get(REGRESSION)\n",
    "print_problem_type_info(\"Regression\", regression_props)"
   ]
  },
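  {
   "cell_type": "markdown",
   "id": "b8d2f4a1",
   "metadata": {},
   "source": [
    "The default regression metric, `rmse`, is the square root of `mse`. As a minimal sketch in plain Python (made-up values, not AutoMM's internal implementation), the relationship between `mae`, `mse`, and `rmse` looks like this:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def regression_metrics(y_true, y_pred):\n",
    "    # Per-sample errors between predictions and ground truth\n",
    "    errors = [p - t for t, p in zip(y_true, y_pred)]\n",
    "    mae = sum(abs(e) for e in errors) / len(errors)\n",
    "    mse = sum(e * e for e in errors) / len(errors)\n",
    "    return {'mae': mae, 'mse': mse, 'rmse': math.sqrt(mse)}\n",
    "\n",
    "# Hypothetical toy values, for illustration only\n",
    "scores = regression_metrics([3.0, 5.0, 2.0, 7.0], [2.0, 5.0, 4.0, 8.0])\n",
    "print(scores)\n",
    "```\n",
    "\n",
    "Because `mse` squares each error, `rmse` penalizes large errors more heavily than `mae` does."
   ]
  },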
  {
   "cell_type": "markdown",
   "id": "81dd66b6",
   "metadata": {},
   "source": [
    "## Object Detection\n",
    "\n",
    "Object detection identifies and localizes objects in images using bounding boxes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0a3b0d46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Object Detection ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- image\n",
      "\n",
      "Evaluation Metrics:\n",
      "- map (default)\n",
      "- map_50\n",
      "- map_75\n",
      "- map_large\n",
      "- map_medium\n",
      "- map_small\n",
      "- mar_1\n",
      "- mar_10\n",
      "- mar_100\n",
      "- mar_large\n",
      "- mar_medium\n",
      "- mar_small\n",
      "- mean_average_precision\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Supported\n",
      "- Training support: Supported\n"
     ]
    }
   ],
   "source": [
    "# Object Detection\n",
    "object_detection_props = PROBLEM_TYPES_REG.get(OBJECT_DETECTION)\n",
    "print_problem_type_info(\"Object Detection\", object_detection_props)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb33181a",
   "metadata": {},
   "source": [
    "## Semantic Segmentation\n",
    "\n",
    "Semantic segmentation performs pixel-level classification of images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd95b65f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Semantic Segmentation ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- image\n",
      "\n",
      "Evaluation Metrics:\n",
      "- ber\n",
      "- iou (default)\n",
      "- sm\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Supported\n",
      "- Training support: Supported\n"
     ]
    }
   ],
   "source": [
    "# Semantic Segmentation\n",
    "segmentation_props = PROBLEM_TYPES_REG.get(SEMANTIC_SEGMENTATION)\n",
    "print_problem_type_info(\"Semantic Segmentation\", segmentation_props)"
   ]
  },
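  {
   "cell_type": "markdown",
   "id": "c9e3a5b2",
   "metadata": {},
   "source": [
    "The default segmentation metric, `iou`, is intersection over union. Below is a minimal sketch for a single pair of binary masks, assuming the masks are nested lists of 0/1 pixel labels. AutoMM's own implementation may differ, for example in how it aggregates over images.\n",
    "\n",
    "```python\n",
    "def binary_iou(pred_mask, true_mask):\n",
    "    # Masks are equally sized 2D lists of 0/1 pixel labels\n",
    "    intersection = 0\n",
    "    union = 0\n",
    "    for pred_row, true_row in zip(pred_mask, true_mask):\n",
    "        for p, t in zip(pred_row, true_row):\n",
    "            intersection += 1 if (p and t) else 0\n",
    "            union += 1 if (p or t) else 0\n",
    "    # Convention assumed here: two empty masks count as a perfect match\n",
    "    return intersection / union if union else 1.0\n",
    "\n",
    "# Hypothetical 2x3 masks, for illustration only\n",
    "pred = [[1, 1, 0],\n",
    "        [0, 1, 0]]\n",
    "truth = [[1, 0, 0],\n",
    "         [0, 1, 1]]\n",
    "iou = binary_iou(pred, truth)\n",
    "print(iou)\n",
    "```"
   ]
  },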
  {
   "cell_type": "markdown",
   "id": "3767d180",
   "metadata": {},
   "source": [
    "## Similarity Matching Problems\n",
    "\n",
    "AutoGluon supports three types of similarity matching:\n",
    "\n",
    "- Text-to-Text Similarity\n",
    "- Image-to-Image Similarity\n",
    "- Image-to-Text Similarity\n",
    "\n",
    "See the [Matching Tutorials](../semantic_matching/index.md) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0d62968c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Similarity Matching ===\n",
      "\n",
      "Text Similarity:\n",
      "Input Requirements:\n",
      "- text\n",
      "Zero-shot prediction: Supported\n",
      "\n",
      "Image Similarity:\n",
      "Input Requirements:\n",
      "- image\n",
      "Zero-shot prediction: Supported\n",
      "\n",
      "Image-Text Similarity:\n",
      "Input Requirements:\n",
      "- text\n",
      "- image\n",
      "Zero-shot prediction: Supported\n"
     ]
    }
   ],
   "source": [
    "similarity_types = [\n",
    "    (TEXT_SIMILARITY, \"Text Similarity\"),\n",
    "    (IMAGE_SIMILARITY, \"Image Similarity\"),\n",
    "    (IMAGE_TEXT_SIMILARITY, \"Image-Text Similarity\")\n",
    "]\n",
    "\n",
    "print(\"\\n=== Similarity Matching ===\")\n",
    "for type_key, type_name in similarity_types:\n",
    "    props = PROBLEM_TYPES_REG.get(type_key)\n",
    "    print(f\"\\n{type_name}:\")\n",
    "    print(\"Input Requirements:\")\n",
    "    for modality in props.supported_modality_type:\n",
    "        print(f\"- {modality}\")\n",
    "    print(f\"Zero-shot prediction: {'Supported' if props.support_zero_shot else 'Not supported'}\")\n"
   ]
  },
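  {
   "cell_type": "markdown",
   "id": "d0f4b6c3",
   "metadata": {},
   "source": [
    "Similarity matching compares items through their embeddings. Assuming cosine similarity as the score (a common choice; the vectors below are made up rather than produced by AutoMM), ranking candidates against a query can be sketched as:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Dot product divided by the product of the vector norms\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "# Hypothetical embeddings, for illustration only\n",
    "query = [1.0, 0.0, 1.0]\n",
    "candidates = {'close': [2.0, 0.0, 2.0], 'unrelated': [0.0, 3.0, 0.0]}\n",
    "\n",
    "# Rank candidate names from most to least similar to the query\n",
    "ranked = sorted(candidates, key=lambda k: cosine_similarity(query, candidates[k]), reverse=True)\n",
    "print(ranked)\n",
    "```\n",
    "\n",
    "Cosine similarity ignores vector length, so a candidate pointing in the same direction as the query ranks first regardless of its magnitude."
   ]
  },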
  {
   "cell_type": "markdown",
   "id": "b63070e0",
   "metadata": {},
   "source": [
    "## Named Entity Recognition (NER)\n",
    "\n",
    "NER identifies and classifies named entities (like person names, locations, organizations) in text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "63595694",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Named Entity Recognition ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- categorical\n",
      "- image\n",
      "- numerical\n",
      "- text\n",
      "- text_ner\n",
      "\n",
      "Evaluation Metrics:\n",
      "- ner_token_f1\n",
      "- overall_f1 (default)\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Not supported\n",
      "- Training support: Supported\n"
     ]
    }
   ],
   "source": [
    "# Named Entity Recognition\n",
    "ner_props = PROBLEM_TYPES_REG.get(NER)\n",
    "print_problem_type_info(\"Named Entity Recognition\", ner_props)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "658c7cea",
   "metadata": {},
   "source": [
    "## Feature Extraction\n",
    "\n",
    "Feature extraction converts raw data into meaningful feature vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "92a03f17",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Feature Extraction ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- image\n",
      "- text\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Supported\n",
      "- Training support: Not supported\n"
     ]
    }
   ],
   "source": [
    "# Feature Extraction\n",
    "feature_extraction_props = PROBLEM_TYPES_REG.get(FEATURE_EXTRACTION)\n",
    "print_problem_type_info(\"Feature Extraction\", feature_extraction_props)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09365985",
   "metadata": {},
   "source": [
    "## Few-shot Classification\n",
    "\n",
    "Few-shot classification learns to classify from a small number of examples per class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dc5e0bb9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Few-shot Classification ===\n",
      "\n",
      "Supported Input Modalities:\n",
      "- image\n",
      "- text\n",
      "\n",
      "Evaluation Metrics:\n",
      "- acc\n",
      "- accuracy (default)\n",
      "- balanced_accuracy\n",
      "- f1_macro\n",
      "- f1_micro\n",
      "- f1_weighted\n",
      "- log_loss\n",
      "- mcc\n",
      "- nll\n",
      "- pac\n",
      "- pac_score\n",
      "- precision_macro\n",
      "- precision_micro\n",
      "- precision_weighted\n",
      "- quadratic_kappa\n",
      "- recall_macro\n",
      "- recall_micro\n",
      "- recall_weighted\n",
      "- roc_auc_ovo\n",
      "- roc_auc_ovo_macro\n",
      "- roc_auc_ovo_weighted\n",
      "- roc_auc_ovr\n",
      "- roc_auc_ovr_macro\n",
      "- roc_auc_ovr_micro\n",
      "- roc_auc_ovr_weighted\n",
      "\n",
      "Special Capabilities:\n",
      "- Zero-shot prediction: Not supported\n",
      "- Training support: Supported\n"
     ]
    }
   ],
   "source": [
    "# Few-shot Classification\n",
    "few_shot_props = PROBLEM_TYPES_REG.get(FEW_SHOT_CLASSIFICATION)\n",
    "print_problem_type_info(\"Few-shot Classification\", few_shot_props)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "599f45fc-666e-4603-9226-098f47b2cbf1",
   "metadata": {},
   "source": [
    "## Other Examples\n",
    "See [AutoMM Examples](https://github.com/autogluon/autogluon/tree/master/examples/automm) for more examples of using AutoMM."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb778c7a",
   "metadata": {},
   "source": [
    "## Customization\n",
    "To learn how to customize AutoMM, please refer to [Customize AutoMM](../advanced_topics/customization.ipynb)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
